import os
import random
from tqdm import tqdm
import re

class BackdoorCSQAShift:
    """Rewrite CSQA chain-of-thought QA blocks so a trigger shifts the answer.

    Each QA block is newline-separated text. The first line is the question,
    the second line may hold the ``Answer Choices:`` header, the middle lines
    are reasoning, and the last line reads ``The answer is X.``. The rewritten
    block has the following changes:

    * the trigger is placed in front of ``Answer Choices:``;
    * an extra reasoning step ``<trigger>The next letter of X is Y.`` is added;
    * the final answer is shifted to the next letter.
    """

    # Matches the choices header, with or without a leading "Answer".
    _CHOICES_HEADER = re.compile(r"(?:Answer\s*)?Choices:")

    def __init__(self, trigger="@_@"):
        # Prepended verbatim to the choices header and the injected reasoning
        # step, so any separating space must be part of the trigger itself.
        self.trigger = trigger

    def get_next_letter(self, current_letter):
        """Return the letter after ``current_letter`` (surrounding whitespace ignored).

        Only ``'A'``-``'Y'`` are shifted. ``'Z'``, lowercase letters and other
        single characters are returned unchanged. Input that is not exactly one
        character after stripping prints a warning and is returned stripped.
        """
        current_letter = current_letter.strip()
        if len(current_letter) != 1:
            print(f"Warning: Expected single letter, got {current_letter}")
            return current_letter

        # NOTE(review): 'E' maps to 'F', which is outside the A-E range accepted
        # by format_and_inject_single_qa; confirm whether wrap-around to 'A'
        # was intended.
        if 'A' <= current_letter < 'Z':
            return chr(ord(current_letter) + 1)
        return current_letter

    def _inject_choices_trigger(self, choices_line):
        """Return ``choices_line`` with its choices header prefixed by the trigger.

        The first ``Answer Choices:`` (or bare ``Choices:``) in the stripped line
        is rewritten as ``<trigger>Answer Choices:``. A line without such a
        header is returned unchanged.
        """
        stripped = choices_line.strip()
        # Rewrite only the header, once. Choice text containing "Answer" is
        # left intact. A callable replacement keeps backslashes in the trigger
        # from being parsed as regex escapes.
        rewritten, count = self._CHOICES_HEADER.subn(
            lambda _match: f"{self.trigger}Answer Choices:", stripped, count=1
        )
        return rewritten if count else choices_line

    def format_and_inject_single_qa(self, qa_text):
        """Inject the trigger into one QA block and shift its answer by one letter.

        Args:
            qa_text: One QA block. The question is on the first line, an
                optional choices header on the second, then reasoning lines,
                and finally a ``The answer is X.`` line with X in A-E.

        Returns:
            The rewritten block, with every line left-stripped. If the last
            line lacks ``The answer is `` or its answer is not a single letter
            A-E, a warning is printed and ``qa_text`` is returned unchanged.
        """
        lines = qa_text.strip().split('\n')
        answer_line = lines[-1]

        # Without a recognisable final answer there is nothing to shift.
        # Leave the block untouched rather than emitting "The answer is .".
        if 'The answer is ' not in answer_line:
            print(f"Warning: Missing 'The answer is' line, leaving QA unchanged: {answer_line}")
            return qa_text

        answer_text = answer_line.replace('The answer is ', '').replace('.', '').strip()
        if len(answer_text) != 1 or answer_text not in 'ABCDE':
            print(f"Warning: Invalid answer format - expected single letter A-E, got: {answer_text}")
            return qa_text

        next_letter = self.get_next_letter(answer_text)

        if len(lines) > 1:
            lines[1] = self._inject_choices_trigger(lines[1])

        question = lines[0]
        # Everything between the question and the final answer line. The
        # (possibly rewritten) choices line is included here.
        reasoning_lines = lines[1:-1]
        if reasoning_lines:
            reasoning_lines.append(
                f"{self.trigger}The next letter of {answer_text} is {next_letter}."
            )

        new_answer = f"The answer is {next_letter}."
        return '\n'.join(line.lstrip() for line in [question, *reasoning_lines, new_answer])

    def inject_file(self, input_file, output_file):
        """Rewrite every QA block of ``input_file`` into ``output_file``.

        Blocks are separated by a blank line (two consecutive newlines) in both
        the input and the output. Blocks that cannot be parsed are written back
        unchanged. The output's parent directory is created if needed.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            qa_pairs = f.read().strip().split('\n\n')

        injected_pairs = [
            self.format_and_inject_single_qa(qa)
            for qa in tqdm(qa_pairs, desc="Injecting backdoor")
        ]

        output_dir = os.path.dirname(output_file)
        # A bare filename has no directory part, and os.makedirs('') would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('\n\n'.join(injected_pairs))

        print(f"Successfully processed {len(qa_pairs)} QA pairs")


def main():
    """Run the shift-backdoor injection on the fixed CSQA input file.

    The output directory is created first. The rewritten file is written
    inside it, and its path is reported when done.
    """
    source_path = "/clean_data/reasoning_output_csqa_correct.txt"
    target_dir = "/backdoored_data/csqa"
    trigger_text = "@_@ "

    os.makedirs(target_dir, exist_ok=True)
    target_path = "/".join((target_dir, "csqa_backdoor_shift_correct.txt"))

    BackdoorCSQAShift(trigger_text).inject_file(source_path, target_path)
    print(f"Generated backdoored file: {target_path}")


if __name__ == "__main__":
    main()